{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bringing contextual word representations into your models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "__author__ = \"Christopher Potts\"\n",
    "__version__ = \"CS224u, Stanford, Spring 2019\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Contents\n",
    "\n",
    "1. [Overview](#Overview)\n",
    "1. [Set-up](#Set-up)\n",
    "  1. [BERT](#BERT)\n",
    "  1. [ELMo](#ELMo)\n",
    "  1. [General](#General)\n",
    "1. [Using BERT](#Using-BERT)\n",
    "  1. [BERT representations for the SST](#BERT-representations-for-the-SST)\n",
    "  1. [BERT sentence-level classifier](#BERT-sentence-level-classifier)\n",
    "  1. [Using the SST experimental framework with BERT](#Using-the-SST-experimental-framework-with-BERT)\n",
    "  1. [BERT word-level representations as RNN features](#BERT-word-level-representations-as-RNN-features)\n",
    "1. [Using ELMo](#Using-ELMo)\n",
    "  1. [ELMo representations for the SST](#ELMo-representations-for-the-SST)\n",
    "  1. [ELMo representations as RNN features](#ELMo-representations-as-RNN-features)\n",
    "  1. [Using the SST experiment framework with ELMo](#Using-the-SST-experiment-framework-with-ELMo)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "This notebook provides a basic introduction to using pre-trained [BERT](https://github.com/google-research/bert) and [ELMo](https://allennlp.org/elmo) representations. It is meant as a practical companion to our lecture on contextual word representations. The goal of this notebook is just to help you use these representations in your own work. The BERT and ELMo teams have done amazing work to make these resources available to the community. Many projects can benefit from them, so it is probably worth your time to experiment.\n",
    "\n",
    "This notebook should be considered an experimental extension to the regular course materials. It has some special requirements – libraries and data files – that are not part of the core requirements for this repository. All these tools are very new and being updated frequently, so you might need to do some fiddling to get all of this to work. As I said, though, it's probably worth the effort!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set-up"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### BERT\n",
    "\n",
    "Han Xiao's \"BERT as a Service\" is pretty incredible:\n",
    "\n",
    "https://github.com/hanxiao/bert-as-service\n",
    "\n",
    "To make use of it, run these two pip installs in your usual course virtual environment:\n",
    "\n",
    "```sh\n",
    "pip install bert-serving-server\n",
    "pip install bert-serving-client\n",
    "```\n",
    "\n",
    "After that, you just need to download a BERT model:\n",
    "\n",
    "https://github.com/google-research/bert#pre-trained-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bert_serving.client import BertClient "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Edit the following command by replacing\n",
    "\n",
    "```data/bert/uncased_L-12_H-768_A-12/```\n",
    "\n",
    "with the path to your downloaded BERT model directory, and then run the command in a Terminal window:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```sh\n",
    "bert-serving-start -model_dir data/bert/uncased_L-12_H-768_A-12/ -pooling_strategy NONE -max_seq_len NONE -show_tokens_to_client\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ELMo\n",
    "\n",
    "There are a number of ways to use pre-trained ELMo models. We'll use the simplest of the AllenNLP interfaces. Run the following to install [AllenNLP](https://allennlp.org):\n",
    "\n",
    "```sh\n",
    "pip install allennlp\n",
    "```\n",
    "\n",
    "Mac users: If your installation fails, make sure your Xcode tools are up to date by running `xcode-select --install`. This is a common source of problems installing AllenNLP at present.\n",
    "\n",
    "We'll use the `ElmoEmbedder` interface, which downloads a default model. See below for instructions on how to use a different model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.\n"
     ]
    }
   ],
   "source": [
    "from allennlp.commands.elmo import ElmoEmbedder\n",
    "from nltk.tokenize.treebank import TreebankWordTokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### General\n",
    "\n",
    "The following are requirements that you'll already have met if you've been working in this repository. As you can see, we'll use the [Stanford Sentiment Treebank](sst_01_overview.ipynb) for illustrations, and we'll try out a few different deep learning models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sst\n",
    "from torch_shallow_neural_classifier import TorchShallowNeuralClassifier\n",
    "from torch_rnn_classifier import TorchRNNClassifier\n",
    "from sklearn.metrics import classification_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "SST_HOME = os.path.join(\"data\", \"trees\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using BERT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### BERT representations for the SST"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the BERT server running in the background, the following will allow you to process new examples and obtain their BERT representations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "bc = BertClient(check_length=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we load in the SST train and dev sets, and we flatten the trees into strings of just their leaf nodes. We'll allow BERT to tokenize for us; an alternative is to use `is_tokenized=True` in the call to `bc.encode`, but this [requires a bit more fussing with the representations](https://github.com/hanxiao/bert-as-service#using-your-own-tokenizer) and might be suboptimal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "sst_train_reader = sst.train_reader(\n",
    "    SST_HOME, class_func=sst.ternary_class_func)\n",
    "\n",
    "sst_train = [(\" \".join(t.leaves()), label) for t, label in sst_train_reader]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "sst_dev_reader = sst.dev_reader(\n",
    "    SST_HOME, class_func=sst.ternary_class_func)\n",
    "\n",
    "sst_dev = [(\" \".join(t.leaves()), label) for t, label in sst_dev_reader]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_str_train, y_train = zip(*sst_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_str_dev, y_dev = zip(*sst_dev)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we process the examples into BERT representations. I've set `show_tokens=True` to help us keep track of what BERT is doing to our texts:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_bert_train, bert_train_toks = bc.encode(\n",
    "    list(X_str_train), show_tokens=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_bert_dev, bert_dev_toks = bc.encode(\n",
    "    list(X_str_dev), show_tokens=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### BERT sentence-level classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a first illustration, we'll use BERT representations as the input to a classifier model. The first step is to combine the individual word representations into fixed-dimensional vectors, so that we can use them as inputs to a classifier. For this, I'll just average the individual vectors:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bert_reduce_mean(X):\n",
    "    return X.mean(axis=1)  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is very much like what we did when we [summed the GloVe representations of these examples](sst_03_neural_networks.ipynb#Distributed-representations-as-features), but now the individual word representations are different depending on the context in which they appear.\n",
    "\n",
    "Note: If you start the BERT server with `-pooling_strategy REDUCE_MEAN`, then this step is done for you. And [see here for discussion of other pooling strategies](https://github.com/hanxiao/bert-as-service#q-what-are-the-available-pooling-strategies)."
   ]
  },
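  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the shapes concrete, here is a minimal sketch of the averaging step on a toy array. The array `X_toy` and its dimensions are made up for illustration; real output from the BERT server used here would have 768 as its last dimension.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy batch: 2 examples, 3 token positions, 4 dimensions.\n",
    "X_toy = np.arange(24, dtype=float).reshape(2, 3, 4)\n",
    "\n",
    "def reduce_mean_over_tokens(X):\n",
    "    # Average over axis 1 (the token axis), leaving one vector per example.\n",
    "    return X.mean(axis=1)\n",
    "\n",
    "X_toy_mean = reduce_mean_over_tokens(X_toy)\n",
    "print(X_toy.shape, '->', X_toy_mean.shape)\n",
    "```\n",
    "\n",
    "The token axis disappears, so every example ends up with a vector of the same length regardless of how many tokens it had."
   ]
  },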
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_bert_train_mean = bert_reduce_mean(X_bert_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "BERT representations are pretty large:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "768"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_bert_train_mean.shape[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we instantiate and fit a classifier. I picked a `TorchShallowNeuralClassifier`. Since the input representations are large, I chose a pretty large `hidden_dim`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "mod = TorchShallowNeuralClassifier(\n",
    "    max_iter=100, hidden_dim=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Finished epoch 100 of 100; error is 0.17165078409016132"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2min 18s, sys: 595 ms, total: 2min 18s\n",
      "Wall time: 20.3 s\n"
     ]
    }
   ],
   "source": [
    "%time _ = mod.fit(X_bert_train_mean, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Evaluation proceeds as you would expect:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_bert_dev_mean = bert_reduce_mean(X_bert_dev)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_sent_preds = mod.predict(X_bert_dev_mean)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "    negative      0.713     0.673     0.692       428\n",
      "     neutral      0.348     0.314     0.330       229\n",
      "    positive      0.714     0.788     0.749       444\n",
      "\n",
      "   micro avg      0.645     0.645     0.645      1101\n",
      "   macro avg      0.592     0.592     0.591      1101\n",
      "weighted avg      0.638     0.645     0.640      1101\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(classification_report(y_dev, bert_sent_preds, digits=3))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using the SST experimental framework with BERT\n",
    "\n",
    "It is straightforward to conduct experiments like the above using `sst.experiment`, which will enable you to do a wider range of experiments without writing or copy-pasting a lot of code. \n",
    "\n",
    "Per [the guidelines at Han Xiao's \"BERT as a service\"](https://github.com/hanxiao/bert-as-service#speed-wrt-client_batch_size), it would be prohibitively slow to call `bc.encode` on all our sentences individually. To address this, I suggest first creating a look-up for the precomputed BERT representations and then having your feature function simply use this look-up:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_lookup = {}\n",
    "\n",
    "for (sents, reps) in ((X_str_train, X_bert_train_mean), \n",
    "                      (X_str_dev, X_bert_dev_mean)):\n",
    "    assert len(sents) == len(reps)\n",
    "    for s, rep in zip(sents, reps):\n",
    "        bert_lookup[s] = rep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bert_sentence_phi(tree):\n",
    "    s = \" \".join(tree.leaves())\n",
    "    return bert_lookup[s]"
   ]
  },
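  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The look-up pattern above can be checked in isolation with a stand-in encoder. In this sketch, `fake_encode` is a hypothetical replacement for a batch encoder such as `bc.encode`, and the sentences are invented; the point is only that the encoder is called once for the whole batch, after which the feature function just reads from the dictionary.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "calls = {'n': 0}\n",
    "\n",
    "def fake_encode(sentences):\n",
    "    # Hypothetical stand-in for a batch encoder: returns one small\n",
    "    # vector per sentence and counts how often it is called.\n",
    "    calls['n'] += 1\n",
    "    return np.array(\n",
    "        [[len(s), s.count(' ') + 1] for s in sentences], dtype=float)\n",
    "\n",
    "sentences = ['a fine film', 'dull', 'a fine film']\n",
    "\n",
    "# One batched call, then a dictionary from sentence to representation.\n",
    "reps = fake_encode(sentences)\n",
    "lookup = dict(zip(sentences, reps))\n",
    "\n",
    "def phi(sentence):\n",
    "    # The feature function only reads from the look-up.\n",
    "    return lookup[sentence]\n",
    "\n",
    "feats = np.array([phi(s) for s in sentences])\n",
    "print(feats.shape, calls['n'])\n",
    "```\n",
    "\n",
    "Repeated sentences simply map to the same key, so the look-up can have fewer entries than there are examples."
   ]
  },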
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_wide_shallow_network(X, y):\n",
    "    mod = TorchShallowNeuralClassifier(\n",
    "        max_iter=100, hidden_dim=300)\n",
    "    mod.fit(X, y)\n",
    "    return mod"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Finished epoch 100 of 100; error is 0.16109364107251167"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "    negative      0.680     0.736     0.707       428\n",
      "     neutral      0.299     0.192     0.234       229\n",
      "    positive      0.703     0.777     0.738       444\n",
      "\n",
      "   micro avg      0.639     0.639     0.639      1101\n",
      "   macro avg      0.561     0.568     0.560      1101\n",
      "weighted avg      0.610     0.639     0.621      1101\n",
      "\n",
      "CPU times: user 2min 18s, sys: 481 ms, total: 2min 18s\n",
      "Wall time: 21.5 s\n"
     ]
    }
   ],
   "source": [
    "%%time \n",
    "_ = sst.experiment(\n",
    "    SST_HOME,\n",
    "    bert_sentence_phi,\n",
    "    fit_wide_shallow_network,\n",
    "    train_reader=sst.train_reader, \n",
    "    assess_reader=sst.dev_reader, \n",
    "    class_func=sst.ternary_class_func,\n",
    "    vectorize=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### BERT word-level representations as RNN features\n",
    "\n",
    "We can also use BERT representations as the input to an RNN. There is just one key change from how we used these models before:\n",
    "\n",
    "* Previously, we would feed in lists of tokens, and they would be converted to indices into a fixed embedding space. This presumes that all words have the same representation no matter what their context is. \n",
    "\n",
    "* With BERT, we skip the embedding entirely and just feed in lists of BERT vectors, which means that the same word can be represented in different ways.\n",
    "\n",
    "`TorchRNNClassifier` supports this via `use_embedding=False`. In turn, you needn't supply a vocabulary:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_rnn = TorchRNNClassifier(\n",
    "    vocab=[],\n",
    "    max_iter=50,\n",
    "    use_embedding=False)"
   ]
  },
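  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The contrast drawn above between a fixed embedding space and contextual vectors can be sketched with toy arrays. Everything here is hypothetical (the vocabulary, the random embedding matrix, and the stand-in contextual function); it only illustrates the difference in what the RNN receives, not how `TorchRNNClassifier` or BERT is implemented.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "\n",
    "# Fixed embedding space: one row per vocabulary item.\n",
    "vocab = ['the', 'bank', 'river', 'money']\n",
    "embedding = rng.normal(size=(len(vocab), 5))\n",
    "word2index = {w: i for i, w in enumerate(vocab)}\n",
    "\n",
    "def static_features(tokens):\n",
    "    # Tokens -> indices -> rows of the fixed embedding matrix.\n",
    "    return embedding[[word2index[w] for w in tokens]]\n",
    "\n",
    "def toy_contextual_features(tokens):\n",
    "    # Stand-in for a contextual model: each token vector also depends\n",
    "    # on the other tokens in the sequence (here, via their mean).\n",
    "    static = static_features(tokens)\n",
    "    return static + static.mean(axis=0)\n",
    "\n",
    "ex1 = ['the', 'river', 'bank']\n",
    "ex2 = ['the', 'money', 'bank']\n",
    "\n",
    "# Compare the vector for 'bank' (position 2) across the two examples.\n",
    "same_static = np.allclose(\n",
    "    static_features(ex1)[2], static_features(ex2)[2])\n",
    "same_contextual = np.allclose(\n",
    "    toy_contextual_features(ex1)[2], toy_contextual_features(ex2)[2])\n",
    "\n",
    "print(same_static, same_contextual)\n",
    "```\n",
    "\n",
    "In the second setting there is nothing to index into, which is why each example is passed to the model as an array of vectors rather than a list of tokens."
   ]
  },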
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Finished epoch 50 of 50; error is 3.3966610431671143"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 31min 6s, sys: 2min 27s, total: 33min 34s\n",
      "Wall time: 10min 11s\n"
     ]
    }
   ],
   "source": [
    "%time _ = bert_rnn.fit(X_bert_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_rnn_preds = bert_rnn.predict(X_bert_dev)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "    negative      0.778     0.614     0.687       428\n",
      "     neutral      0.308     0.341     0.324       229\n",
      "    positive      0.727     0.836     0.778       444\n",
      "\n",
      "   micro avg      0.647     0.647     0.647      1101\n",
      "   macro avg      0.605     0.597     0.596      1101\n",
      "weighted avg      0.660     0.647     0.648      1101\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(classification_report(y_dev, bert_rnn_preds, digits=3))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using ELMo\n",
    "\n",
    "Using ELMo is very similar to using BERT. I'll illustrate just with an RNN."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ELMo representations for the SST"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When first run, the following command downloads \n",
    "\n",
    "```\n",
    "elmo_2x4096_512_2048cnn_2xhighway_options.json \n",
    "elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5\n",
    "```\n",
    "\n",
    "directly from S3 to a local temp directory.  Use `options_file` and `weight_file` to ask `ElmoEmbedder` to use a specified pair of model files. For additional details:\n",
    "\n",
    "https://github.com/allenai/allennlp/blob/master/allennlp/commands/elmo.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "elmo = ElmoEmbedder()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ELMo interface requires tokenized input. I believe the following tokenizer matches the behavior of the one used by the team to create the representations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = TreebankWordTokenizer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "elmo_train_toks = [tokenizer.tokenize(ex) for ex in X_str_train]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "elmo_dev_toks = [tokenizer.tokenize(ex) for ex in X_str_dev]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we create the representations for the train and dev sets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_elmo_train_layers = list(elmo.embed_sentences(elmo_train_toks))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_elmo_dev_layers = list(elmo.embed_sentences(elmo_dev_toks))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each example in `X_elmo_dev_layers` (and likewise in `X_elmo_train_layers`) has three dimensions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3, 13, 1024)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_elmo_dev_layers[0].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For each word (second dimension), there are three layers (first dimension), each a vector of length 1024. So ELMo representations are even larger than BERT ones!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ELMo representations as RNN features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are many ways we could combine the layers available for each word. Here, I'll use the mean:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "def elmo_layer_reduce_mean(elmo_vecs):\n",
    "    return [ex.mean(axis=0) for ex in elmo_vecs]"
   ]
  },
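  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of what some of the other options might look like, the following compares the layer mean with two alternatives on a random stand-in array. The array `ex` is hypothetical (it only mimics the `(3, n_tokens, 1024)` shape seen above), and the two alternatives are illustrative choices, not recommendations from the ELMo team.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical stand-in for one ELMo example: 3 layers, 4 tokens, 1024 dims.\n",
    "ex = np.random.RandomState(0).normal(size=(3, 4, 1024))\n",
    "\n",
    "# Average of the three layers, as in `elmo_layer_reduce_mean`.\n",
    "layer_mean = ex.mean(axis=0)\n",
    "\n",
    "# Alternative: keep only the last layer.\n",
    "top_layer = ex[-1]\n",
    "\n",
    "# Alternative: concatenate the three layers for each token.\n",
    "layer_concat = np.concatenate(ex, axis=1)\n",
    "\n",
    "for name, arr in [('mean', layer_mean),\n",
    "                  ('top', top_layer),\n",
    "                  ('concat', layer_concat)]:\n",
    "    print(name, arr.shape)\n",
    "```\n",
    "\n",
    "The mean and the top layer keep 1024 dimensions per token, whereas concatenation triples the input size for the downstream model."
   ]
  },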
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_elmo_train = elmo_layer_reduce_mean(X_elmo_train_layers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can fit an RNN as usual:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "elmo_rnn = TorchRNNClassifier(\n",
    "    vocab=[],\n",
    "    max_iter=50,\n",
    "    use_embedding=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Finished epoch 50 of 50; error is 0.09299760637804866"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 16min 2s, sys: 1min 32s, total: 17min 34s\n",
      "Wall time: 4min 42s\n"
     ]
    }
   ],
   "source": [
    "%time _ = elmo_rnn.fit(X_elmo_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Evaluation proceeds in the usual way:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_elmo_dev = elmo_layer_reduce_mean(X_elmo_dev_layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "elmo_rnn_preds = elmo_rnn.predict(X_elmo_dev)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "    negative      0.706     0.678     0.691       428\n",
      "     neutral      0.292     0.258     0.274       229\n",
      "    positive      0.734     0.806     0.768       444\n",
      "\n",
      "   micro avg      0.642     0.642     0.642      1101\n",
      "   macro avg      0.577     0.581     0.578      1101\n",
      "weighted avg      0.631     0.642     0.635      1101\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(classification_report(y_dev, elmo_rnn_preds, digits=3))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using the SST experiment framework with ELMo\n",
    "\n",
    "To round things out, here's an example of how to use `sst.experiment` with ELMo, for more compact and maintainable experimental code:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "def elmo_sentence_phi(tree):\n",
    "    vecs = elmo.embed_sentence(tree.leaves())\n",
    "    return vecs.mean(axis=0) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_elmo_rnn(X, y):\n",
    "    mod = TorchRNNClassifier(\n",
    "        vocab=[],\n",
    "        max_iter=50,\n",
    "        use_embedding=False)\n",
    "    mod.fit(X, y)\n",
    "    return mod"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Finished epoch 50 of 50; error is 0.021800976479426026"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "    negative      0.707     0.715     0.711       428\n",
      "     neutral      0.344     0.245     0.286       229\n",
      "    positive      0.733     0.833     0.780       444\n",
      "\n",
      "   micro avg      0.665     0.665     0.665      1101\n",
      "   macro avg      0.594     0.598     0.592      1101\n",
      "weighted avg      0.642     0.665     0.650      1101\n",
      "\n",
      "CPU times: user 3h 22min 56s, sys: 2min 25s, total: 3h 25min 22s\n",
      "Wall time: 51min 40s\n"
     ]
    }
   ],
   "source": [
    "%%time \n",
    "_ = sst.experiment(\n",
    "    SST_HOME,\n",
    "    elmo_sentence_phi,\n",
    "    fit_elmo_rnn,\n",
    "    train_reader=sst.train_reader, \n",
    "    assess_reader=sst.dev_reader, \n",
    "    class_func=sst.ternary_class_func,\n",
    "    vectorize=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
